# Hardware requested for the task.
resources:
  # Accelerator spec in <name>:<count> form; quoted so it is always read as a string.
  accelerators: 'L4:2'
  # Presumably "at least 64 GB" in SkyPilot's resource syntax; confirm against the task YAML spec.
  memory: '64+'

# Number of nodes to launch. The run script derives its own node count from
# SKYPILOT_NODE_IPS rather than reading this value directly.
num_nodes: 2

# Working directory for the task: the directory containing this file's launch
# context. NOTE(review): the run script invokes train.py by relative path, so it
# presumably lives here; confirm.
workdir: .

# Environment preparation script (bash).
setup: |
  # Enter the "ray" conda env; if activation fails, create it with
  # Python 3.10 and activate the fresh env instead.
  conda activate ray || {
    conda create -n ray python=3.10 -y
    conda activate ray
  }

  # Python dependencies. The torch packages are pulled from the PyTorch
  # wheel index for CUDA 11.8 (cu118).
  pip install "ray[train]"
  pip install tqdm
  pip install torch torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/cu118

# Job script (bash). The node with rank 0 starts the Ray head and launches
# training; every other node joins the head as a Ray worker.
run: |
  # NOTE(review): makes /var/tmp world-writable; presumably needed for temp
  # files written at runtime. Confirm this is still required.
  sudo chmod 777 -R /var/tmp
  conda activate ray

  # SKYPILOT_NODE_IPS is read as one IP per line: the first line is used as
  # the Ray head address and the line count as the number of nodes.
  head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  num_nodes=$(echo "$SKYPILOT_NODE_IPS" | wc -l)

  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    # Start the head only when no process matching "ray" and "6379" is
    # already running, so re-running the job reuses the existing cluster.
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats --port 6379
    # Fixed grace period before training starts; workers join concurrently.
    sleep 5
    python train.py --num-workers "$num_nodes"
  else
    # Fixed grace period so the head has a chance to come up first.
    sleep 5
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --address "$head_ip:6379" --disable-usage-stats
    # Keep this node's command alive briefly after `ray start` so Ray has
    # time to daemonize before the script exits.
    sleep 5
  fi
